{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|default_exp data.metadatasets"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Metadataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    ">A dataset of datasets\n",
    "\n",
    "This functionality will allow you to create a dataset from data stores in multiple, smaller datasets.\n",
    "\n",
    "I'd like to thank both **Thomas Capelle** (https://github.com/tcapelle)  and **Xander Dunn** (https://github.com/xanderdunn) for their contributions to make this code possible. \n",
    "\n",
    "This functionality allows you to use multiple numpy arrays instead of a single one, which may be very useful in many practical settings. It's been tested it with 10k+ datasets and it works well. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "from tsai.imports import *\n",
    "from tsai.utils import *\n",
    "from tsai.data.validation import *\n",
    "from tsai.data.core import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class TSMetaDataset():\n",
    "    _type = (TSTensor,)\n",
    "    \" A dataset capable of indexing mutiple datasets at the same time\"\n",
    "    def __init__(self, dataset_list, **kwargs):\n",
    "        if not is_listy(dataset_list): dataset_list = [dataset_list]\n",
    "        self.datasets = dataset_list\n",
    "        self.split = kwargs['split'] if 'split' in kwargs else None            \n",
    "        self.mapping = self._mapping()\n",
    "        if hasattr(dataset_list[0], 'loss_func'): \n",
    "            self.loss_func =  dataset_list[0].loss_func\n",
    "        else: \n",
    "            self.loss_func = None\n",
    "\n",
    "    def __len__(self):\n",
    "        if self.split is not None: \n",
    "            return len(self.split)\n",
    "        else:\n",
    "            return sum([len(ds) for ds in self.datasets])\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        if self.datasets:\n",
    "            if self.split is not None: idx = self.split[idx]\n",
    "            idx = listify(idx)\n",
    "            idxs = self.mapping[idx]\n",
    "            idxs = idxs[idxs[:, 0].argsort()]\n",
    "            self.mapping_idxs = idxs\n",
    "            ds = np.unique(idxs[:, 0])\n",
    "            b = [self.datasets[d][idxs[idxs[:, 0] == d, 1]] for d in ds]\n",
    "            output = tuple(map(torch.cat, zip(*b)))\n",
    "            output = self._type[0](output[0]), output[1]\n",
    "            return output\n",
    "        else:\n",
    "            return\n",
    "\n",
    "    def _mapping(self):\n",
    "        lengths = [len(ds) for ds in self.datasets]\n",
    "        idx_pairs = np.zeros((np.sum(lengths), 2)).astype(np.int32)\n",
    "        start = 0\n",
    "        for i,length in enumerate(lengths):\n",
    "            if i > 0: \n",
    "                idx_pairs[start:start+length, 0] = i\n",
    "            idx_pairs[start:start+length, 1] = np.arange(length)\n",
    "            start += length\n",
    "        return idx_pairs\n",
    "\n",
    "    def new_empty(self): \n",
    "        new_dset = type(self)(self.datasets, split=self.split)\n",
    "        new_dset.datasets = None\n",
    "        return new_dset\n",
    "    \n",
    "    @property\n",
    "    def vars(self):\n",
    "        s = self.datasets[0][0][0] if not isinstance(self.datasets[0][0][0], tuple) else self.datasets[0][0][0][0]\n",
    "        return s.shape[-2]\n",
    "    @property\n",
    "    def len(self): \n",
    "        s = self.datasets[0][0][0] if not isinstance(self.datasets[0][0][0], tuple) else self.datasets[0][0][0][0]\n",
    "        return s.shape[-1]\n",
    "    @property\n",
    "    def vocab(self): \n",
    "        return self.datasets[0].vocab\n",
    "    @property\n",
    "    def cat(self): return hasattr(self, \"vocab\")\n",
    "\n",
    "\n",
    "class TSMetaDatasets(FilteredBase):\n",
    "    def __init__(self, metadataset, splits):\n",
    "        store_attr()\n",
    "        self.mapping = metadataset.mapping\n",
    "        self.datasets = metadataset.datasets\n",
    "    def subset(self, i):\n",
    "        return type(self.metadataset)(self.metadataset.datasets, split=self.splits[i])\n",
    "    @property\n",
    "    def train(self): \n",
    "        return self.subset(0)\n",
    "    @property\n",
    "    def valid(self): \n",
    "        return self.subset(1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's create 3 datasets. In this case they will have different sizes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(TSTensor(samples:64, vars:5, len:50, device=cpu, dtype=torch.float32),\n",
       " TensorCategory([1, 0, 3, 9, 7, 2, 8, 6, 1, 1, 1, 8, 1, 1, 9, 2, 6, 6, 1, 5, 5,\n",
       "                 6, 9, 2, 7, 1, 6, 4, 9, 2, 5, 0, 4, 9, 1, 4, 4, 6, 0, 8, 8, 5,\n",
       "                 8, 6, 9, 0, 8, 8, 6, 4, 8, 9, 7, 3, 4, 7, 7, 8, 6, 2, 3, 0, 7,\n",
       "                 4]))"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "vocab = alphabet[:10]\n",
    "dsets = []\n",
    "for i in range(3):\n",
    "    size = np.random.randint(50, 150)\n",
    "    X = torch.rand(size, 5, 50)\n",
    "    y = vocab[torch.randint(0, 10, (size,))]\n",
    "    tfms = [None, TSClassification(vocab=vocab)]\n",
    "    dset = TSDatasets(X, y, tfms=tfms)\n",
    "    dsets.append(dset)\n",
    "\n",
    "\n",
    "\n",
    "metadataset = TSMetaDataset(dsets)\n",
    "splits = TimeSplitter(show_plot=False)(metadataset)\n",
    "metadatasets = TSMetaDatasets(metadataset, splits=splits)\n",
    "dls = TSDataLoaders.from_dsets(metadatasets.train, metadatasets.valid)\n",
    "xb, yb = dls.train.one_batch()\n",
    "xb, yb"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can train metadatasets as you would train any other time series model in `tsai`:\n",
    "\n",
    "```python\n",
    "learn = ts_learner(dls, arch=\"TSTPlus\")\n",
    "learn.fit_one_cycle(1)\n",
    "learn.export(\"test.pkl\")\n",
    "```"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For inference, you should create the new metadatasets using the same method you used when you trained it. The you use fastai's learn.get_preds method to generate predictions: \n",
    "\n",
    "```python\n",
    "vocab = alphabet[:10]\n",
    "dsets = []\n",
    "for i in range(3):\n",
    "    size = np.random.randint(50, 150)\n",
    "    X = torch.rand(size, 5, 50)\n",
    "    y = vocab[torch.randint(0, 10, (size,))]\n",
    "    tfms = [None, TSClassification(vocab=vocab)]\n",
    "    dset = TSDatasets(X, y, tfms=tfms)\n",
    "    dsets.append(dset)\n",
    "metadataset = TSMetaDataset(dsets)\n",
    "dl = TSDataLoader(metadataset)\n",
    "\n",
    "\n",
    "learn = load_learner(\"test.pkl\")\n",
    "learn.get_preds(dl=dl)\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "There also en easy way to map any particular sample in a batch to the original dataset and id: "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dls = TSDataLoaders.from_dsets(metadatasets.train, metadatasets.valid)\n",
    "xb, yb = first(dls.train)\n",
    "mappings = dls.train.dataset.mapping_idxs\n",
    "for i, (xbi, ybi) in enumerate(zip(xb, yb)):\n",
    "    ds, idx = mappings[i]\n",
    "    test_close(dsets[ds][idx][0].data.cpu(), xbi.cpu())\n",
    "    test_close(dsets[ds][idx][1].data.cpu(), ybi.cpu())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For example the 3rd sample in this batch would be: "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([  0, 112], dtype=int32)"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dls.train.dataset.mapping_idxs[2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/javascript": "IPython.notebook.save_checkpoint();",
      "text/plain": [
       "<IPython.core.display.Javascript object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/Users/nacho/notebooks/tsai/nbs/008_data.metadatasets.ipynb saved at 2023-03-24 11:30:57\n",
      "Correct notebook to script conversion! 😃\n",
      "Friday 24/03/23 11:31:00 CET\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "                <audio  controls=\"controls\" autoplay=\"autoplay\">\n",
       "                    <source src=\"data:audio/wav;base64,UklGRvQHAABXQVZFZm10IBAAAAABAAEAECcAACBOAAACABAAZGF0YdAHAAAAAPF/iPh/gOoOon6w6ayCoR2ZeyfbjobxK+F2Hs0XjKc5i3DGvzaTlEaraE+zz5uLUl9f46fHpWJdxVSrnfmw8mYEScqUP70cb0Q8X41uysJ1si6Eh1jYzXp9IE2DzOYsftYRyoCY9dJ/8QICgIcEun8D9PmAaBPlfT7lq4MFIlh61tYPiCswIHX+yBaOqT1QbuW7qpVQSv9lu6+xnvRVSlyopAypbGBTUdSalrSTaUBFYpInwUpxOzhti5TOdndyKhCGrdwAfBUcXIJB69p+Vw1egB76+n9q/h6ADglbf4LvnIHfF/981ODThF4m8HiS0riJVjQ6c+/EOZCYQfJrGrhBmPVNMmNArLKhQlkXWYqhbaxXY8ZNHphLuBJsZUEckCTFVHMgNKGJytIDeSUmw4QN4Qx9pReTgb3vYX/TCBuApf75f+P5Y4CRDdN+B+tngk8c8nt03CKGqipgd13OhotwOC5x9MCAknFFcmlmtPmagFFFYOCo0qRzXMhVi57pryNmIEqJlRi8bm52PfuNM8k4dfQv+4cO12l6zCGdg3jl730uE/KAPvS+f0wEAoAsA89/XfXQgBESIn6S5luDtiC8eh/YmIfpLqt1OMp5jXg8/24MveqUNUnPZsqw0Z3yVDldnaUOqIZfXlKrm36zzWhjRhaT+r+ncHI5/otUzfd2uSt7hl/bqXtoHaCC6+mqfrAOeoDD+PJ/xf8RgLMHfH/b8GeBihZIfSXidoQSJWB52NM1iRkzz3MkxpKPbUCrbDu5d5fgTAxkSK3JoEhYD1p2omere2LZTuqYLbdWa49Cx5Dww7tyXDUnioXRkHhwJyKFvd/AfPoYy4Fl7j1/LQorgEr9/X89+0qAOAwAf13sJoL8Gkd8wt25hWIp3Heez/eKODfPcSPCzpFNRDVqf7UlmnNQKGHgqd+jgVvJVm2f265QZTpLS5byur1tpT6ajvrHq3Q2MXWIxtUCehoj8YMk5LB9hRQegeTypn+nBQWA0QHgf7f2q4C5EFt+5ucOg2YfHXtq2SSHpS0ydnTL4IxFO6pvNb4ulBdInWfcsfSc7VMmXpSmE6eeXmZThJxpsgRohEfOk86+AHCoOpOMFsx1dv8s6oYT2k17uR7ngpXod34IEJqAaPfnfyABCIBZBpl/NPI2gTQVjX134x2ExSPMeR7VtYjZMWJ0W8ftjkA/YW1durCWykvjZFKu4p9LVwVbZKNkqpxh6U+6mRC2mGq2Q3SRvsIgcpc2sIpD0Bp4uiiFhW3ecXxOGgaCDe0Vf4cLPoDv+/5/mfw1gN4KKX+17emBqBmYfBHfVYUZKFR44NBtiv41bHJUwx+RJkP1apu2VJlkTwli4qrwoo1ax1dToNCtemRSTBGXz7kJbdM/PY/Dxht0dTLziH7Ul3loJEiE0uJsfdsVTYGL8Yt/AgcMgHYA7X8S+IqAYA+QfjzpxIIVHnp7tdqzhmAstXaxzEqMETpScGC/dJP3Rmdo8LIZnOVSEF+Opxumsl1sVF+dVrE5Z6NIiZSkvVdv2zsqjdnK8HVDLlyHyNjuegogM4NA5z9+YRG9gA722H97AgOA/gSyf43zCIHdE899yuTIg3ciNXpm1jmImTDwdJPITI4RPhRugbvslbFKt2Vfr/6eTFb4W1WkY6m6YPdQjJr2tNZp3EQlko7BgXHRNz2LAc+gdwMq7IUf3R58ohtFgrbr6n7hDFWAlPr8f/T9I4CECU9/De+vgVQY5nxh4POEzybJeCTS5YnCNAZzhsRzkP1Bsmu4t4aYU07nYuerA6KWWcJYO6HHrKJjaE3Zl624UWz/QOOPjcWHc7QzdIk40yl5tCWjhIDhJX0xF4CBMvBsf10IF4Ac//Z/bPlsgAcOwn6S6n6CwxzUewLcRoYaKzV38M23i9o493CNwL6S1UUuaQe0QpvbUfdfiqglpcRccFU+nkWwambASUiVfLyqbg49xY2eyWh1hy/Sh37XjHpaIYKD7OUEfrgS5IC09MV/1gMBgKMDyH/n9N6AhhINfh7mdoMoIZt6r9fAh1cvfHXNya6N4DzDbqi8K5WWSYlmbbAdnkpV6FxJpWSo1V8DUmGb3rMRaQBG2JJgwN9wCDnNi8HNI3dKK1aG0dvHe/UciIJf6rt+Og5wgDn59X9P/xWAKQhxf2XweYH+FjB9suGVhIMlOnlo02GJhTOdc7vFyo/TQGxs2Li7lz9NwmPurBihnVi7WSWiwKvGYntOpJiOt5drKUKMkFnE8HLxNPmJ9NG4eP8mAYUv4Np8hhi3gdruSX+3CSWAwP38f8f6UoCuDPF+6Os8gnAbKnxQ3d2F0imydzDPKIuiN5lxu8EKkrFE82kftW2az1DbYImpMqTUW3FWIJ83r5hl2koJlla7+m0+PmSOZcjcdMgwS4g11iZ6qCLUg5jkxn0QFA6BWvOvfzEFBIBHAtp/Qfa3gC4RSH5y5yeD2B/8evnYS4cULgR2CMsUja47cG/QvW6UeEhXZ3+xP51GVNVdP6Zpp+1eDFM5nMeySWghR4+TNL85cD46YIyCzKJ2kCzEhoTabXtGHs+CCemJfpMPjoDe9+t/qQALgM8Gj3++8UaBqRV2fQTjO4Q3JKd5r9TgiEYyMHTxxiWPpz8jbfq585YpTJpk960xoKFXsVoTo7yq6GGMTw==\" type=\"audio/wav\" />\n",
       "                    Your browser does not support the audio element.\n",
       "                </audio>\n",
       "              "
      ],
      "text/plain": [
       "<IPython.lib.display.Audio object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "#|eval: false\n",
    "#|hide\n",
    "from tsai.export import get_nb_name; nb_name = get_nb_name(locals())\n",
    "from tsai.imports import create_scripts; create_scripts(nb_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "python3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
